# Sampling

Learn how to implement MCP clients that can handle sampling requests from servers, enabling bidirectional communication where clients provide LLM capabilities to servers.

## Overview

Sampling allows MCP clients to respond to LLM completion requests from servers. When a server needs to generate content, answer questions, or perform reasoning tasks, it can send a sampling request to the client, which then processes it using an LLM and returns the result.

## Implementing a Sampling Handler

Create a sampling handler by implementing the `SamplingHandler` interface:

```go
package main

import (
    "context"
    "fmt"
    "log"

    "github.com/mark3labs/mcp-go/client"
    "github.com/mark3labs/mcp-go/client/transport"
    "github.com/mark3labs/mcp-go/mcp"
)

type MySamplingHandler struct {
    // Add fields for your LLM client (OpenAI, Anthropic, etc.)
}

func (h *MySamplingHandler) CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) {
    // Extract request parameters
    messages := request.Messages
    systemPrompt := request.SystemPrompt
    maxTokens := request.MaxTokens
    temperature := request.Temperature
    
    // Process with your LLM
    response, err := h.callLLM(ctx, messages, systemPrompt, maxTokens, temperature)
    if err != nil {
        return nil, fmt.Errorf("LLM call failed: %w", err)
    }
    
    // Return MCP-formatted result
    return &mcp.CreateMessageResult{
        Model: "your-model-name",
        Role:  mcp.RoleAssistant,
        Content: mcp.TextContent{
            Type: "text",
            Text: response,
        },
        StopReason: "endTurn",
    }, nil
}

func (h *MySamplingHandler) callLLM(ctx context.Context, messages []mcp.SamplingMessage, systemPrompt string, maxTokens int, temperature float64) (string, error) {
    // Implement your LLM integration here
    // This is where you'd call OpenAI, Anthropic, or other LLM APIs
    return "Your LLM response here", nil
}
```

## Configuring the Client

Enable sampling by providing a handler when creating the client:

```go
func main() {
    // Create sampling handler
    samplingHandler := &MySamplingHandler{}
    
    // Create stdio transport
    stdioTransport := transport.NewStdio("/path/to/mcp/server", nil)
    
    // Create client with sampling support
    mcpClient := client.NewClient(stdioTransport, client.WithSamplingHandler(samplingHandler))
    
    // Start the client
    ctx := context.Background()
    if err := mcpClient.Start(ctx); err != nil {
        log.Fatalf("Failed to start client: %v", err)
    }
    defer mcpClient.Close()
    if err := mcpClient.Connect(ctx); err != nil {
        log.Fatalf("Failed to connect: %v", err)
    }
    
    // The client will now automatically handle sampling requests
    // from the server using your handler
}
```

## Mock Implementation Example

Here's a complete mock implementation for testing:

```go
package main

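import (
    "context"
    "fmt"
    "log"
    "os"

    "github.com/mark3labs/mcp-go/client"
    "github.com/mark3labs/mcp-go/client/transport"
    "github.com/mark3labs/mcp-go/mcp"
)

// MockSamplingHandler returns canned responses instead of calling a real LLM.
type MockSamplingHandler struct{}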

func (h *MockSamplingHandler) CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) {
    // Log the request for debugging
    log.Printf("Mock LLM received sampling request:")
    log.Printf("  System prompt: %s", request.SystemPrompt)
    log.Printf("  Max tokens: %d", request.MaxTokens)
    log.Printf("  Temperature: %f", request.Temperature)
    
    // Extract the user's message
    var userMessage string
    for _, msg := range request.Messages {
        if msg.Role == mcp.RoleUser {
            if textContent, ok := msg.Content.(mcp.TextContent); ok {
                userMessage = textContent.Text
                log.Printf("  User message: %s", userMessage)
                break
            }
        }
    }
    
    // Generate a mock response
    mockResponse := fmt.Sprintf(
        "Mock LLM response to: '%s'. This is a simulated response from a mock LLM handler.",
        userMessage,
    )
    
    return &mcp.CreateMessageResult{
        Model: "mock-llm-v1",
        Role:  mcp.RoleAssistant,
        Content: mcp.TextContent{
            Type: "text",
            Text: mockResponse,
        },
        StopReason: "endTurn",
    }, nil
}

func main() {
    if len(os.Args) < 2 {
        log.Fatal("Usage: sampling_client <server_path>")
    }
    
    serverPath := os.Args[1]
    
    // Create stdio transport
    stdioTransport := transport.NewStdio(serverPath, nil)
    
    // Create client with mock sampling handler
    mcpClient := client.NewClient(stdioTransport, client.WithSamplingHandler(&MockSamplingHandler{}))
    
    // Start the client
    ctx := context.Background()
    if err := mcpClient.Start(ctx); err != nil {
        log.Fatalf("Failed to start client: %v", err)
    }
    defer mcpClient.Close()
    if err := mcpClient.Connect(ctx); err != nil {
        log.Fatalf("Failed to connect: %v", err)
    }
    
    // Test server tools that use sampling
    result, err := mcpClient.CallTool(ctx, "ask_llm", map[string]any{
        "question": "What is the capital of France?",
        "system_prompt": "You are a helpful geography assistant.",
    })
    if err != nil {
        log.Fatalf("Tool call failed: %v", err)
    }
    
    fmt.Printf("Tool result: %+v\n", result)
}
```

## Real LLM Integration

### OpenAI Integration

```go
import (
    "context"
    "fmt"

    "github.com/mark3labs/mcp-go/mcp"
    "github.com/sashabaranov/go-openai"
)

type OpenAISamplingHandler struct {
    client *openai.Client
}

func NewOpenAISamplingHandler(apiKey string) *OpenAISamplingHandler {
    return &OpenAISamplingHandler{
        client: openai.NewClient(apiKey),
    }
}

func (h *OpenAISamplingHandler) CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) {
    // Convert MCP messages to OpenAI format
    var messages []openai.ChatCompletionMessage
    
    // Add system message if provided
    if request.SystemPrompt != "" {
        messages = append(messages, openai.ChatCompletionMessage{
            Role:    openai.ChatMessageRoleSystem,
            Content: request.SystemPrompt,
        })
    }
    
    // Convert MCP messages
    for _, msg := range request.Messages {
        var role string
        switch msg.Role {
        case mcp.RoleUser:
            role = openai.ChatMessageRoleUser
        case mcp.RoleAssistant:
            role = openai.ChatMessageRoleAssistant
        default:
            // Skip messages whose role this handler does not know how to map
            continue
        }
        
        if textContent, ok := msg.Content.(mcp.TextContent); ok {
            messages = append(messages, openai.ChatCompletionMessage{
                Role:    role,
                Content: textContent.Text,
            })
        }
    }
    
    // Create OpenAI request
    req := openai.ChatCompletionRequest{
        Model:       openai.GPT3Dot5Turbo,
        Messages:    messages,
        MaxTokens:   request.MaxTokens,
        Temperature: float32(request.Temperature),
    }
    
    // Call OpenAI API
    resp, err := h.client.CreateChatCompletion(ctx, req)
    if err != nil {
        return nil, fmt.Errorf("OpenAI API call failed: %w", err)
    }
    
    if len(resp.Choices) == 0 {
        return nil, fmt.Errorf("no response from OpenAI")
    }
    
    choice := resp.Choices[0]
    
    // Convert stop reason
    var stopReason string
    switch choice.FinishReason {
    case "stop":
        stopReason = "endTurn"
    case "length":
        stopReason = "maxTokens"
    default:
        stopReason = "other"
    }
    
    return &mcp.CreateMessageResult{
        Model: resp.Model,
        Role:  mcp.RoleAssistant,
        Content: mcp.TextContent{
            Type: "text",
            Text: choice.Message.Content,
        },
        StopReason: stopReason,
    }, nil
}
```

### Anthropic Integration

```go
import (
    "bytes"
    "context"
    "encoding/json"
    "fmt"
    "net/http"

    "github.com/mark3labs/mcp-go/mcp"
)

type AnthropicSamplingHandler struct {
    apiKey string
    client *http.Client
}

func NewAnthropicSamplingHandler(apiKey string) *AnthropicSamplingHandler {
    return &AnthropicSamplingHandler{
        apiKey: apiKey,
        client: &http.Client{},
    }
}

func (h *AnthropicSamplingHandler) CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) {
    // Convert to Anthropic format
    anthropicReq := map[string]any{
        "model":      "claude-3-sonnet-20240229",
        "max_tokens": request.MaxTokens,
        "messages":   h.convertMessages(request.Messages),
    }
    
    if request.SystemPrompt != "" {
        anthropicReq["system"] = request.SystemPrompt
    }
    
    if request.Temperature > 0 {
        anthropicReq["temperature"] = request.Temperature
    }
    
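    // Make API call
    reqBody, err := json.Marshal(anthropicReq)
    if err != nil {
        return nil, fmt.Errorf("failed to encode Anthropic request: %w", err)
    }
    httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost,
        "https://api.anthropic.com/v1/messages", bytes.NewBuffer(reqBody))
    if err != nil {
        return nil, fmt.Errorf("failed to build Anthropic request: %w", err)
    }

    httpReq.Header.Set("Content-Type", "application/json")
    httpReq.Header.Set("x-api-key", h.apiKey)
    httpReq.Header.Set("anthropic-version", "2023-06-01")

    resp, err := h.client.Do(httpReq)
    if err != nil {
        return nil, fmt.Errorf("Anthropic API call failed: %w", err)
    }
    defer resp.Body.Close()

    // Treat non-200 responses as errors instead of decoding them as results
    if resp.StatusCode != http.StatusOK {
        return nil, fmt.Errorf("Anthropic API returned status %s", resp.Status)
    }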
    
    var anthropicResp struct {
        Content []struct {
            Text string `json:"text"`
            Type string `json:"type"`
        } `json:"content"`
        Model      string `json:"model"`
        StopReason string `json:"stop_reason"`
    }
    
    if err := json.NewDecoder(resp.Body).Decode(&anthropicResp); err != nil {
        return nil, fmt.Errorf("failed to decode Anthropic response: %w", err)
    }
    
    // Extract text content
    var text string
    for _, content := range anthropicResp.Content {
        if content.Type == "text" {
            text += content.Text
        }
    }
    
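    // Convert Anthropic's snake_case stop reason to the camelCase form used by MCP;
    // unrecognized values are passed through unchanged
    stopReason := anthropicResp.StopReason
    switch anthropicResp.StopReason {
    case "end_turn":
        stopReason = "endTurn"
    case "max_tokens":
        stopReason = "maxTokens"
    case "stop_sequence":
        stopReason = "stopSequence"
    }

    return &mcp.CreateMessageResult{
        Model: anthropicResp.Model,
        Role:  mcp.RoleAssistant,
        Content: mcp.TextContent{
            Type: "text",
            Text: text,
        },
        StopReason: stopReason,
    }, nil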
}

func (h *AnthropicSamplingHandler) convertMessages(messages []mcp.SamplingMessage) []map[string]any {
    var result []map[string]any
    for _, msg := range messages {
        if textContent, ok := msg.Content.(mcp.TextContent); ok {
            result = append(result, map[string]any{
                "role":    string(msg.Role),
                "content": textContent.Text,
            })
        }
    }
    return result
}
```

## Automatic Capability Declaration

When you provide a sampling handler, the client automatically declares the sampling capability during initialization:

```go
// This automatically adds sampling capability
stdioTransport := transport.NewStdio(serverPath, nil)
mcpClient := client.NewClient(stdioTransport, client.WithSamplingHandler(handler))
```

The client will include this in the initialization request:

```json
{
  "jsonrpc": "2.0",
  "id": 1,
  "method": "initialize",
  "params": {
    "protocolVersion": "2024-11-05",
    "capabilities": {
      "sampling": {}
    },
    "clientInfo": {
      "name": "your-client",
      "version": "1.0.0"
    }
  }
}
```
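
As a rough illustration of what this declaration means on the receiving side, the sketch below checks whether a raw `initialize` message contains a `sampling` entry under `capabilities`. The `initializeRequest` type and the `declaresSampling` function are hypothetical names used only for this example; mcp-go performs its own capability handling, and this is not its implementation.

```go
package main

import (
    "encoding/json"
    "fmt"
)

// initializeRequest mirrors only the parts of the message needed here.
// It is an illustrative type, not part of mcp-go.
type initializeRequest struct {
    Method string `json:"method"`
    Params struct {
        Capabilities map[string]json.RawMessage `json:"capabilities"`
    } `json:"params"`
}

// declaresSampling reports whether raw is an initialize request whose
// capabilities object contains a "sampling" key.
func declaresSampling(raw []byte) (bool, error) {
    var req initializeRequest
    if err := json.Unmarshal(raw, &req); err != nil {
        return false, fmt.Errorf("invalid initialize request: %w", err)
    }
    if req.Method != "initialize" {
        return false, nil
    }
    _, ok := req.Params.Capabilities["sampling"]
    return ok, nil
}

func main() {
    raw := []byte(`{
        "jsonrpc": "2.0",
        "id": 1,
        "method": "initialize",
        "params": {"capabilities": {"sampling": {}}}
    }`)
    ok, err := declaresSampling(raw)
    if err != nil {
        panic(err)
    }
    fmt.Println("client declares sampling:", ok)
}
```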

## Error Handling

Handle errors gracefully in your sampling handler:

```go
func (h *MySamplingHandler) CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) {
    // Validate request
    if len(request.Messages) == 0 {
        return nil, fmt.Errorf("no messages provided")
    }
    
    // Check for context cancellation
    if err := ctx.Err(); err != nil {
        return nil, fmt.Errorf("request cancelled: %w", err)
    }
    
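    // Call LLM with error handling
    text, err := h.callLLM(ctx, request.Messages, request.SystemPrompt, request.MaxTokens, request.Temperature)
    if err != nil {
        // Log error for debugging
        log.Printf("LLM call failed: %v", err)

        // Return appropriate error
        if strings.Contains(err.Error(), "rate limit") {
            return nil, fmt.Errorf("rate limit exceeded, please try again later")
        }
        return nil, fmt.Errorf("LLM service unavailable: %w", err)
    }

    // Wrap the LLM text in an MCP-formatted result
    return &mcp.CreateMessageResult{
        Model: "your-model-name",
        Role:  mcp.RoleAssistant,
        Content: mcp.TextContent{
            Type: "text",
            Text: text,
        },
        StopReason: "endTurn",
    }, nil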
}
```

## Best Practices

1. **Implement Proper Error Handling**: Always handle LLM API errors gracefully
2. **Respect Rate Limits**: Implement rate limiting and backoff strategies
3. **Validate Inputs**: Check message content and parameters before processing
4. **Use Context**: Respect context cancellation and timeouts
5. **Log Appropriately**: Log requests for debugging but avoid logging sensitive data
6. **Model Selection**: Allow configuration of which LLM model to use
7. **Content Filtering**: Implement content filtering if required by your use case

## Testing Your Implementation

Test your sampling handler with the sampling server example:

```bash
# Build the sampling server
cd examples/sampling_server
go build -o sampling_server

# Build your client
go build -o my_client

# Test the integration
./my_client ./sampling_server
```

## Transport Support

Sampling is available on the following transports:

### STDIO Transport

For STDIO clients, create the transport and client separately:

```go
stdioTransport := transport.NewStdio("/path/to/server", nil)
mcpClient := client.NewClient(stdioTransport, client.WithSamplingHandler(&MySamplingHandler{}))
```

### In-Process Transport

For in-process clients, use the dedicated constructor:

```go
mcpClient, err := client.NewInProcessClientWithSamplingHandler(
    mcpServer, 
    &MySamplingHandler{},
)
```

In-process sampling uses direct method calls instead of JSON-RPC serialization.

### Unsupported Transports

SSE and StreamableHTTP transports do not support sampling due to their one-way or stateless nature.

## Next Steps

- Learn about [server-side sampling implementation](/servers/advanced-sampling)
- Explore [client operations](/clients/operations)
- Check out the [sampling examples](https://github.com/mark3labs/mcp-go/tree/main/examples/sampling_client)
- See [in-process transport documentation](/transports/inprocess) for embedded scenarios